#!/usr/bin/env python3

#==============================================================================
# Functionality
#==============================================================================
import pdb
import sys
import os
import re

# Utility functions, classes, etc. go here.

def asserting(cond):
    """Assert that *cond* is truthy, opening the debugger first if it is not.

    When *cond* is falsy, ``pdb.set_trace()`` is called so the failure can be
    inspected interactively; the plain ``assert`` that follows then fails.
    """
    holds = bool(cond)
    if not holds:
        pdb.set_trace()
    assert cond

def has_stdin():
    """Return True when standard input is not attached to a terminal."""
    stdin_is_terminal = sys.stdin.isatty()
    return not stdin_is_terminal

def reg(pat, flags=0):
    """Compile *pat* as a verbose regular expression.

    ``re.VERBOSE`` is always enabled; *flags* are OR-ed in on top of it.
    """
    combined_flags = re.VERBOSE | flags
    compiled = re.compile(pat, combined_flags)
    return compiled

#==============================================================================
# Cmdline
#==============================================================================
import argparse

# Command-line interface. RawTextHelpFormatter keeps the description's line
# breaks exactly as written when --help is printed.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""
Initialize the TPU system at the given TPU name or address through a
TensorFlow session.
""")

parser.add_argument('-v', '--verbose',
    action="store_true",
    help="verbose output")

parser.add_argument('-t', '--timeout',
    type=int,
    default=10,
    help="Session timeout in seconds")

parser.add_argument('tpu_name',
    help="TPU name or address")

# Parsed command-line arguments; stays None until main() fills it in.
args = None

#==============================================================================
# Main
#==============================================================================

import tensorflow as tf
from tensorflow.contrib import tpu
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.core.protobuf import config_pb2

def get_tpu_addr(tpu_name=None):
    """Return the master address of a TPU, or None if none is configured.

    Lookup order:
      1. *tpu_name*, when it is not None.
      2. A default ``TPUClusterResolver()`` when ``COLAB_TPU_ADDR`` is set in
         the environment.
      3. The ``TPU_NAME`` environment variable.

    The address is whatever ``TPUClusterResolver(...).get_master()`` returns.
    """
    environ = os.environ
    if tpu_name is not None:
        resolver = TPUClusterResolver(tpu_name)
    elif 'COLAB_TPU_ADDR' in environ:
        resolver = TPUClusterResolver()
    elif 'TPU_NAME' in environ:
        resolver = TPUClusterResolver(environ['TPU_NAME'])
    else:
        # Nothing names a TPU: no argument and neither environment variable.
        return None
    return resolver.get_master()

def run():
    """Initialize the TPU system selected by the parsed command-line arguments.

    Reads the module-level ``args`` (``verbose``, ``tpu_name``, ``timeout``).
    Prints a message and exits with status 1 when no TPU address can be
    resolved. Otherwise opens a session on that address, runs
    ``tpu.initialize_system()`` in it and prints 'Done'.
    """
    if args.verbose:
        print(args)
    addr = get_tpu_addr(args.tpu_name)
    if not addr:
        print("TPU not found!", args.tpu_name)
        sys.exit(1)
    # --timeout is given in seconds; the TensorFlow options take milliseconds.
    timeout_in_ms = int(args.timeout * 1000)
    config = config_pb2.ConfigProto(operation_timeout_in_ms=timeout_in_ms)
    run_options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
    # Hand the config to the session so its operation timeout actually takes
    # effect, in addition to the per-run timeout passed to sess.run().
    with tf.compat.v1.Session(addr, config=config) as sess:
        print('Initializing TPU', addr)
        sess.run(tpu.initialize_system(), options=run_options)
    print('Done')

def main():
    """Parse the command line if that has not happened yet, then call run().

    Arguments the parser does not recognize are kept on the namespace as
    ``args.args``. Returns whatever ``run()`` returns.
    """
    global args
    already_parsed = bool(args)
    if not already_parsed:
        namespace, unrecognized = parser.parse_known_args()
        namespace.args = unrecognized
        args = namespace
    return run()

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

